485be6
@@ -24,16 +24,21 @@
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.FunctionInfo;
 import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.optimizer.optiq.OptiqSemanticException;
 import org.apache.hadoop.hive.ql.parse.ASTNode;
 import org.apache.hadoop.hive.ql.parse.HiveParser;
 import org.apache.hadoop.hive.ql.parse.ParseDriver;
+import org.apache.hadoop.hive.ql.udf.SettableUDF;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNegative;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPositive;
+import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
 import org.eigenbase.reltype.RelDataType;
 import org.eigenbase.reltype.RelDataTypeFactory;
 import org.eigenbase.sql.SqlAggFunction;
@@ -96,10 +101,12 @@
private static FunctionInfo handleExplicitCast(SqlOperator op, RelDataType dt) {
 
       if (castType.equals(TypeInfoFactory.byteTypeInfo)) {
         castUDF = FunctionRegistry.getFunctionInfo("tinyint");
-      } else if (castType.equals(TypeInfoFactory.charTypeInfo)) {
-        castUDF = FunctionRegistry.getFunctionInfo("char");
-      } else if (castType.equals(TypeInfoFactory.varcharTypeInfo)) {
-        castUDF = FunctionRegistry.getFunctionInfo("varchar");
+      } else if (castType instanceof CharTypeInfo) { // any parameterization, not one instance
+        castUDF = handleCastForParameterizedType(castType,
+          FunctionRegistry.getFunctionInfo("char"));
+      } else if (castType instanceof VarcharTypeInfo) { // any parameterization, not one instance
+        castUDF = handleCastForParameterizedType(castType,
+          FunctionRegistry.getFunctionInfo("varchar"));
       } else if (castType.equals(TypeInfoFactory.stringTypeInfo)) {
         castUDF = FunctionRegistry.getFunctionInfo("string");
       } else if (castType.equals(TypeInfoFactory.booleanTypeInfo)) {
@@ -118,16 +125,48 @@
private static FunctionInfo handleExplicitCast(SqlOperator op, RelDataType dt) {
         castUDF = FunctionRegistry.getFunctionInfo("timestamp");
       } else if (castType.equals(TypeInfoFactory.dateTypeInfo)) {
         castUDF = FunctionRegistry.getFunctionInfo("datetime");
-      } else if (castType.equals(TypeInfoFactory.decimalTypeInfo)) {
-        castUDF = FunctionRegistry.getFunctionInfo("decimal");
+      } else if (castType instanceof DecimalTypeInfo) { // any parameterization, not one instance
+        castUDF = handleCastForParameterizedType(castType,
+          FunctionRegistry.getFunctionInfo("decimal"));
       } else if (castType.equals(TypeInfoFactory.binaryTypeInfo)) {
         castUDF = FunctionRegistry.getFunctionInfo("binary");
-      }
+      } else {
+        throw new IllegalStateException("Unexpected type : " + castType.getQualifiedName());
+      }
     }
 
     return castUDF;
   }
 
+  /**
+   * Builds the cast function for a parameterized target type (char, varchar, decimal)
+   * by configuring the registered cast UDF with the exact target TypeInfo.
+   *
+   * @param ti the target type of the cast, including its type parameters
+   * @param fi the registered cast function; its GenericUDF must implement SettableUDF
+   * @return a new FunctionInfo with the same nativeness and display name as {@code fi},
+   *     wrapping the configured UDF
+   * @throws IllegalStateException if the UDF of {@code fi} is missing or not a SettableUDF
+   * @throws RuntimeException if the UDF rejects {@code ti}
+   */
+  private static FunctionInfo handleCastForParameterizedType(TypeInfo ti, FunctionInfo fi) {
+    GenericUDF genericUDF = fi.getGenericUDF();
+    if (!(genericUDF instanceof SettableUDF)) {
+      throw new IllegalStateException("Cast function " + fi.getDisplayName()
+          + " cannot be parameterized with type " + ti.getQualifiedName());
+    }
+    // NOTE(review): setTypeInfo mutates the UDF in place; that is only safe if
+    // getGenericUDF() returns a private copy rather than a shared registry instance.
+    // Confirm against FunctionInfo before relying on it.
+    try {
+      ((SettableUDF) genericUDF).setTypeInfo(ti);
+    } catch (UDFArgumentException e) {
+      throw new RuntimeException("Cannot set type " + ti.getQualifiedName()
+          + " on cast function " + fi.getDisplayName(), e);
+    }
+    return new FunctionInfo(fi.isNative(), fi.getDisplayName(), genericUDF);
+  }
+
   // TODO: 1) handle Agg Func Name translation 2) is it correct to add func args
   // as child of func?
   public static ASTNode buildAST(SqlOperator op, List<ASTNode> children) {
